#!/usr/bin/env python
import sys, os, time, signal, random, traceback, errno
import multiprocessing, multiprocessing.queues
from optparse import OptionParser

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, 'common')))
import utils

r = utils.import_python_driver()

def call_ignore_interrupt(fun):
    """Call `fun()` and return its result, transparently retrying whenever
    the call fails with EINTR (interrupted system call).

    Any IOError with a different errno propagates to the caller.
    """
    result = None
    finished = False
    while not finished:
        try:
            result = fun()
            finished = True
        except IOError as err:
            # Only swallow signal interruptions; re-raise real I/O failures
            if err.errno != errno.EINTR:
                raise
    return result

def stress_client_proc(options, start_event, exit_event, stat_queue, host_offset, random_seed):
    """Body of one client worker process.

    Connects to one of the configured hosts (chosen round-robin via
    `host_offset`), waits for `start_event`, then sends throttled workload
    queries until `exit_event` is set.  Reconnects every `ops_per_conn`
    operations (0 means never reconnect).  Puts "ready" on `stat_queue`
    once initialized, then a stat dict per operation.
    """
    target = options["hosts"][host_offset % len(options["hosts"])]
    host = target[0]
    port = target[1]
    ops_per_conn = options["ops_per_conn"]

    random.seed(random_seed)

    # Give the first connection a random partial op budget so the clients
    # don't all reconnect at the same moment
    ops_done = int(random.random() * ops_per_conn)

    runner = QueryThrottler(options, stat_queue)
    stat_queue.put("ready")

    start_event.wait()
    while not exit_event.is_set():
        with r.connect(host, port) as conn:
            # ops_per_conn == 0 means "keep this connection forever"
            while ops_per_conn == 0 or ops_done < ops_per_conn:
                if exit_event.is_set():
                    break
                runner.send_query(conn)
                ops_done += 1
            ops_done = 0

def spawn_clients(options, start_event, exit_event, stat_queue):
    """Fork one `stress_client_proc` worker per configured client.

    Seeds the parent RNG (from options["seed"] or a fresh random value,
    echoed to stderr for reproducibility), hands each child its own derived
    seed and a round-robin host offset, then blocks until every child has
    reported "ready" on `stat_queue`.  Returns the list of started
    multiprocessing.Process objects.
    """
    num_clients = options["clients"]

    seed = options["seed"]
    if seed is None:
        seed = random.random()

    sys.stderr.write("Random seed used: %f\n" % seed)
    random.seed(seed)

    procs = []
    for idx in range(num_clients):
        proc = multiprocessing.Process(target=stress_client_proc,
                                       args=(options,
                                             start_event,
                                             exit_event,
                                             stat_queue,
                                             idx,
                                             random.random()))
        procs.append(proc)
        proc.start()

    # Block until every child has finished its setup and checked in
    for _ in range(num_clients):
        reply = call_ignore_interrupt(stat_queue.get)
        if reply != "ready":
            raise RuntimeError("Unexpected response from client: %s" % str(reply))

    return procs

def stop_clients(exit_event, child_procs_list, num_clients, timeout):
    """Signal all client processes to stop and wait for them to exit.

    Sets `exit_event`, sends SIGTERM to every child, then polls until every
    child has exited or `timeout` seconds have elapsed.  Children that exit
    with a code other than 0 or -15 (SIGTERM), or that are still alive at
    the deadline, are reported to stdout and the script exits with status 1.

    Returns the wall-clock time at which shutdown began, used by the caller
    as the workload end time.
    """
    exit_event.set()
    end_time = time.time()
    child_procs = set(child_procs_list)

    # Deadline after which still-running children count as hung
    kill_time = end_time + timeout

    failed_procs = []

    for proc in child_procs:
        proc.terminate()

    while True:
        # Reap children that have exited; -15 is a normal SIGTERM death
        for proc in child_procs.copy():
            if not proc.is_alive():
                if proc.exitcode not in [0, -15]:
                    failed_procs.append(proc)
                child_procs.remove(proc)
        if not child_procs:
            # Everyone has exited cleanly or been reaped as failed
            break
        if time.time() > kill_time:
            # Deadline passed: report the stragglers as failures
            failed_procs += list(child_procs)
            break
        time.sleep(0.1)

    if failed_procs:
        # A still-alive (timed-out) process has exitcode None, hence the
        # falsy check to pick the "timed out" wording
        reasons = ["process " + str(child_procs_list.index(proc) + 1) +
                   (" failed with code " + str(proc.exitcode)
                    if proc.exitcode else " timed out")
                   for proc in failed_procs]
        sys.stdout.write("Sub-processes failed: " + ", ".join(reasons) + "\n")
        sys.exit(1)

    return end_time

# Write stats to stderr, so they don't interfere with parsers
def print_stats(stats, start_time, end_time, num_clients):
    """Print the aggregated run statistics to stderr.

    `stats` holds "count" (ops completed), "latency" (summed per-op latency
    in seconds) and "errors" (error message -> occurrence count).  All
    rate/percentage divisions are guarded so a zero-length run (or zero
    clients) prints "inf" instead of raising ZeroDivisionError.
    """
    duration = end_time - start_time

    err = sys.stderr.write
    err("Duration: %0.3f seconds\n" % duration)

    # Print stats table
    err("\n")
    err("Operations data: \n")

    table = [["total", "per sec", "per sec client avg", "avg latency"]]

    if duration != 0.0 and num_clients != 0:
        per_sec = "%0.3f" % (stats["count"] / duration)
        per_sec_per_client = "%0.3f" % (stats["count"] / duration / num_clients)
    else:
        per_sec = "inf"
        per_sec_per_client = "inf"

    if stats["count"] != 0:
        latency = "%0.6f" % (stats["latency"] / stats["count"])
    else:
        latency = "%0.6f" % (0.0)

    table.append([str(stats["count"]), per_sec, per_sec_per_client, latency])

    # Size each column to its widest entry plus two spaces of padding
    column_widths = []
    for i in range(len(table[0])):
        column_widths.append(max([len(row[i]) + 2 for row in table]))

    format_str = ("{:<%d}" + ("{:>%d}" * (len(column_widths) - 1))) % tuple(column_widths)

    # Fraction of total client wall-clock time spent inside queries
    rql_time_spent = stats["latency"]
    total_client_time = duration * num_clients
    if total_client_time != 0.0:
        err("Percent time clients spent in ReQL space: %.2f\n" %
            (100 * rql_time_spent / total_client_time))
    else:
        # Zero-length run or no clients: the original unguarded division
        # would raise ZeroDivisionError here
        err("Percent time clients spent in ReQL space: inf\n")

    for row in table:
        err(format_str.format(*row) + "\n")

    # Print errors
    if len(stats["errors"]) != 0:
        err("\n")
        err("Errors encountered:\n")
        for error, count in stats["errors"].items():
            err("%s: %s\n" % (error, count))

def interrupt_handler(signum, frame, exit_event, parent_pid):
    """SIGINT handler: request an orderly shutdown by setting `exit_event`.

    Only the parent (controller) process reacts; forked children inherit
    the handler and must ignore the signal so the controller can stop them
    itself.  `signum`/`frame` are the standard handler arguments (unused);
    the first was renamed from `signal` to stop shadowing the signal module.
    """
    if os.getpid() == parent_pid:
        exit_event.set()

class QueryThrottler:
    """Paces workload queries toward a global ops/sec target and reports
    per-operation stats.

    Each client process owns one instance; the per-client rate is the
    global options["ops_per_sec"] divided across options["clients"]
    workers.  An ops_per_sec of 0 disables throttling entirely.
    """

    def __init__(self, options, stat_queue):
        self.workload = options["workload"]
        self.stat_queue = stat_queue
        ops_per_sec = options["ops_per_sec"]
        if ops_per_sec != 0:
            # Seconds each op should take for THIS client
            self.secs_per_op = float(options["clients"]) / ops_per_sec
        else:
            self.secs_per_op = 0

        # Wall-clock time the next query is due; lazily initialized on the
        # first throttled send so clients can desynchronize
        self.next_query_time = None

    def _wait_for_slot(self):
        # Sleep until this client's next scheduled slot, never letting the
        # schedule drift more than 10 ops into the past.
        if self.next_query_time is None:
            # Desync from other clients by waiting a random amount of time
            # less than one op's duration
            time.sleep(random.random() * self.secs_per_op)
            self.next_query_time = time.time()

        now = time.time()
        overdue = now - self.next_query_time

        if overdue > 10 * self.secs_per_op:
            # Don't fall more than 10 ops behind, or we'd flood the server
            # if/when it recovers
            self.next_query_time = now - (10 * self.secs_per_op)
        else:
            self.next_query_time += self.secs_per_op
            if overdue < 0.0:
                # Ahead of schedule: wait out the remainder of the slot
                time.sleep(-overdue)

    def send_query(self, conn):
        """Run one workload operation on `conn` and push a stat dict
        (timestamp, latency, optional errors list) onto the stat queue."""
        if self.secs_per_op != 0:
            self._wait_for_slot()

        started = time.time()
        stat = {"timestamp": started}
        try:
            stat.update(self.workload.run(conn))
        except (r.ReqlError, r.ReqlDriverError) as ex:
            stat["errors"] = stat.get('errors', []) + [ex.message]
        except (IOError, OSError) as ex:
            if ex.errno != errno.EINTR:
                raise
            stat["errors"] = stat.get('errors', []) + ["Interrupted system call"]
        stat["latency"] = time.time() - started

        self.stat_queue.put(stat)

# Main loop for the main stress process
def stress_controller(options):
    """Spawn the client processes, aggregate their per-op stats until a stop
    condition fires (op count reached, duration elapsed, or SIGINT), then
    stop the clients and print a summary to stderr.
    """
    stat_queue = multiprocessing.queues.SimpleQueue()
    start_event = multiprocessing.Event()  # released once all clients are ready
    exit_event = multiprocessing.Event()   # set to tell all processes to stop
    child_procs = []
    op_count = options["op_count"]  # 0 means "no op-count limit"

    # Register interrupt, now that we're spawning client processes
    # (children inherit this handler; interrupt_handler checks the pid so
    # only the parent reacts)
    parent_pid = os.getpid()
    signal.signal(signal.SIGINT, lambda a, b: interrupt_handler(a, b, exit_event, parent_pid))

    child_procs.extend(spawn_clients(options, start_event, exit_event, stat_queue))

    # Aggregated totals; "errors" maps error message -> occurrence count
    stats = {"count": 0, "latency": 0.0, "errors": {}}
    start_time = time.time()

    try:
        # All clients are blocked on this event; release them together
        start_event.set()

        # Collect stats as they come in
        while not exit_event.is_set():
            if not stat_queue.empty():
                stat = call_ignore_interrupt(stat_queue.get)

                # Print new stat
                if not options["quiet"]:
                    print "%0.3f: %0.6f" % (stat["timestamp"], stat["latency"])
                    sys.stdout.flush()

                stats["count"] += 1
                stats["latency"] += stat["latency"]
                for error in stat.get("errors", []):
                    if not options["ignore_errors"]:
                        print "  - %s" % error
                    stats["errors"][error] = stats["errors"].get(error, 0) + 1

                # Stop once the requested number of ops has completed
                if op_count != 0 and stats["count"] >= op_count:
                    exit_event.set()

            else:
                # Queue drained: check the wall-clock limit, then idle briefly
                if options["duration"] != 0:
                    if time.time() - start_time > options["duration"]:
                        exit_event.set()
                time.sleep(0.1)

    except:
        traceback.print_exc()
        sys.exit(1)
    finally:
        # Allow some time for operations to complete
        stop_timeout = max(2.0, 3.0 / (options["ops_per_sec"] or 1))
        end_time = stop_clients(exit_event, child_procs, options["clients"], stop_timeout)

        # Drain any stats produced while the clients were shutting down so
        # the summary counts every completed op
        while not stat_queue.empty():
            stat = call_ignore_interrupt(stat_queue.get)
            stats["count"] += 1
            stats["latency"] += stat["latency"]
            for error in stat.get("errors", []):
                stats["errors"][error] = stats["errors"].get(error, 0) + 1

        print_stats(stats, start_time, end_time, options["clients"])

if __name__ == "__main__":
    # Command-line interface.  optparse (not argparse) and the `print`
    # statements below match the Python 2 interpreter this script targets.
    parser = OptionParser()
    parser.add_option("--seed", dest="seed", metavar="FLOAT", default=None, type="float")
    parser.add_option("--table", dest="db_table", metavar="DB.TABLE", default=None, type="string")
    parser.add_option("--ops-per-sec", dest="ops_per_sec", metavar="NUMBER", default=100, type="float")
    parser.add_option("--ops-per-conn", dest="ops_per_conn", metavar="NUMBER", default=0, type="int")
    parser.add_option("--clients", dest="clients", metavar="CLIENTS", default=64, type="int")
    parser.add_option("--workload", "-w", dest="workload", metavar="WORKLOAD", default=None, type="string")
    # --host may be given multiple times; each use appends a HOST:PORT pair
    parser.add_option("--host", dest="hosts", metavar="HOST:PORT", action="append", default=[], type="string")
    parser.add_option("--quiet", dest="quiet", action="store_true", default=False)
    # --duration and --op-count of 0 mean "unlimited"
    parser.add_option("--duration", dest="duration", type="int", default=0)
    parser.add_option("--op-count", dest="op_count", type="int", default=0)
    parser.add_option("--ignore-errors", dest="ignore_errors", action="store_true", default=False)
    (parsed_options, args) = parser.parse_args()
    # Flatten the parsed options into the plain dict handed to every
    # client process (it must be picklable for multiprocessing)
    options = {
        "clients": parsed_options.clients,
        "ops_per_sec": parsed_options.ops_per_sec,
        "ops_per_conn": parsed_options.ops_per_conn,
        "quiet": parsed_options.quiet,
        "hosts": [],
        "duration": parsed_options.duration,
        "op_count": parsed_options.op_count,
        "ignore_errors": parsed_options.ignore_errors
    }

    if len(args) != 0:
        print "no positional arguments supported"
        exit(1)

    # Parse out host/port pairs
    for host_port in parsed_options.hosts:
        (host, port) = host_port.split(":")
        options["hosts"].append((host, int(port)))
    if len(options["hosts"]) == 0:
        # Default to a local server on the standard RethinkDB driver port
        options["hosts"].append(("localhost", 28015))

    if parsed_options.workload is None:
        print "no workload specified"
        exit(1)

    # Get table name, and make sure it exists on the server
    if parsed_options.db_table is None:
        options["db"] = os.environ.get('DB_NAME', 'test')
        options["table"] = os.environ.get('TABLE_NAME', 'stress')
    else:
        options["db"], options["table"] = parsed_options.db_table.split(".")

    # Create the database/table up front (via the first host) so the
    # client processes can assume they exist
    with r.connect(options["hosts"][0][0], options["hosts"][0][1]) as connection:
        if options["db"] not in r.db_list().run(connection):
            r.db_create(options["db"]).run(connection)

        if options["table"] not in r.db(options["db"]).table_list().run(connection):
            r.db(options["db"]).table_create(options["table"]).run(connection)
        else:
            # Using existing table
            pass

    # Parse out workload info
    # Add the stress_workload subdirectory to the import search path
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'stress_workloads')))
    # The workload module must expose a Workload class taking the options dict
    options["workload"] = __import__(parsed_options.workload).Workload(options)
    options["seed"] = parsed_options.seed

    stress_controller(options)
